{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction to Transformers\n",
    "\n",
    "In this notebook, we\"re going to install a transformer model, analyze the embedding output, and compare some vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "from aips import *\n",
    "import pandas\n",
    "import numpy\n",
    "import sentence_transformers\n",
    "from IPython.display import display, HTML"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a506a145dab74d5ca415cca086c9685b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6da4e060bdf44236aba1bb62e0b7c95c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f5ad65b2b53a4462b69f15f4fa93a297",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/4.03k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ca838ce44f44ac3bbfb45db8d62b415",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "107fcfd2ac3f4b1b916a350e57997db5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/688 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c22b5e9bfe584d10b6c70a259fa2eeaa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "396f9a7459d34fada262bee7f533dc13",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/334 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18012e2845b94a7394b9f38e15ba21e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8a72e10c785463195feaf5979f8c2a1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "457ecf7892fe41499e1396c3702315f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8266ae1cf6d040faac6fe37f100370da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "269008b9e7c94bc09fa1812e83f094e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78d9198558fa409f9c5100019cf99dba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "transformer = SentenceTransformer(\"roberta-base-nli-stsb-mean-tokens\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of embeddings: 4\n",
      "Dimensions per embedding: 768\n",
      "The embedding feature values of \"it's raining hard\":\n",
      "tensor([ 1.1609e-01, -1.8422e-01,  4.1023e-01,  2.8474e-01,  5.8746e-01,\n",
      "         7.4418e-02, -5.6910e-01, -1.5300e+00, -1.4629e-01,  7.9517e-01,\n",
      "         5.0953e-01,  3.5076e-01, -6.7288e-01, -2.9603e-01, -2.3220e-01,\n",
      "         4.8475e-01,  9.9531e-01,  6.1437e-01,  7.8995e-02, -7.7781e-01,\n",
      "         1.0021e+00,  3.5468e-01, -1.7309e-01,  3.9410e-01,  1.6540e-01,\n",
      "        -1.2335e-01,  6.1811e-01,  3.6482e-01,  3.2900e-01,  1.0812e+00,\n",
      "        -5.4269e-01, -2.2409e-01, -1.4409e+00,  8.2625e-01, -1.1814e+00,\n",
      "        -4.4101e-01,  5.8892e-01, -1.5328e+00, -2.9688e-01, -6.7420e-03,\n",
      "         4.8688e-01,  1.0616e-01, -2.7084e-01,  2.4105e-02,  2.8154e-01,\n",
      "        -2.0851e-01, -2.3183e-01, -1.0750e+00, -6.3038e-01, -7.4111e-02,\n",
      "         3.0137e-01, -1.2426e+00, -2.5925e-01, -3.2736e-02,  9.0476e-01,\n",
      "         1.0643e-02,  1.2886e-01,  8.2835e-01, -1.5106e+00, -4.7442e-01,\n",
      "         1.5537e+00, -8.3725e-01, -2.5802e-01,  7.2701e-01, -2.1782e-01,\n",
      "         4.3834e-01,  8.1870e-01,  2.1506e+00,  7.7013e-02,  3.9540e-01,\n",
      "         3.5280e-01,  1.5424e-01,  1.1395e+00, -7.5375e-01, -5.7804e-01,\n",
      "        -1.4815e-01, -4.7606e-01,  1.5028e+00,  5.2689e-01, -4.3875e-01,\n",
      "         1.0173e-01, -1.3668e-02, -8.6674e-01, -5.1597e-01, -8.8089e-01,\n",
      "        -8.8594e-01,  1.6453e+00,  1.6563e-01,  1.5489e-01, -4.6433e-01,\n",
      "         2.9270e-01,  5.5046e-01,  1.1388e+00, -2.2292e-02, -1.9230e-01,\n",
      "         2.5127e-01, -2.6604e-01, -1.6304e-01,  4.6697e-01,  6.4704e-01,\n",
      "         2.3353e-01, -5.0705e-01,  6.7264e-01,  5.3285e-01, -7.9935e-01,\n",
      "         1.3630e+00, -6.5025e-01,  2.6927e-01, -9.6264e-01,  7.0896e-01,\n",
      "         4.5403e-01, -1.2476e+00,  5.8096e-02,  7.3235e-01, -1.7337e-01,\n",
      "         5.0296e-02, -1.6796e-01, -1.4571e-01, -5.2013e-01, -9.9281e-01,\n",
      "        -8.0377e-01, -9.8628e-02,  5.9392e-01,  5.8502e-01, -1.0408e+00,\n",
      "         3.4451e-02, -3.6206e-02, -5.1456e-01,  4.8950e-01,  1.1598e-01,\n",
      "        -5.3393e-01,  6.8377e-02,  1.2736e-01,  3.8780e-01, -1.4060e-01,\n",
      "         3.3287e-01,  4.9323e-01, -5.7388e-01, -9.9959e-01, -1.2496e+00,\n",
      "        -1.2059e-01,  8.1342e-01,  1.1275e+00,  1.3159e+00, -8.2315e-01,\n",
      "        -5.2154e-01,  1.1336e+00,  6.5252e-01,  1.8138e-02, -2.3254e-01,\n",
      "         1.6957e-01,  1.1297e+00,  1.2883e-02,  6.4734e-01, -2.3145e-01,\n",
      "        -6.0603e-01, -1.3832e+00,  1.1276e-01, -4.2431e-01, -9.0949e-02,\n",
      "         2.0123e+00,  3.4952e-01,  2.4911e-01,  6.4084e-01, -6.8175e-01,\n",
      "         6.8169e-01, -1.1436e+00,  7.1149e-01,  7.5227e-01,  1.7830e-01,\n",
      "         2.1357e-01,  7.4176e-01,  5.1335e-01,  3.8950e-01,  8.2767e-01,\n",
      "        -4.3359e-01,  1.6286e+00, -6.5435e-01, -2.5960e-01, -9.0844e-01,\n",
      "         1.1767e-01, -7.7510e-02,  8.5900e-02,  2.1811e-01,  2.9793e-01,\n",
      "        -1.3657e+00,  5.8282e-01,  1.2325e-01,  4.2484e-01,  4.4059e-01,\n",
      "        -2.3504e-01,  4.9956e-02,  4.8132e-01, -1.7567e-02,  1.1849e+00,\n",
      "         1.0027e-01,  3.0968e-01,  2.8310e-01,  2.1691e-01,  4.4285e-01,\n",
      "        -7.1970e-01, -4.4008e-01, -9.0192e-01, -1.5634e-01,  1.3950e+00,\n",
      "        -5.1620e-01, -9.4348e-01,  1.2406e-01,  1.0990e-01,  8.6775e-01,\n",
      "         2.9539e-02,  8.7526e-04,  5.0737e-01,  1.7828e-01, -1.2491e+00,\n",
      "         5.8967e-01,  5.2507e-02,  4.9740e-01,  6.9696e-01,  2.9045e-02,\n",
      "        -3.0555e-01, -1.5936e-01,  4.5473e-02,  8.1945e-01, -4.0761e-01,\n",
      "        -1.9063e-01,  5.1713e-02, -4.2057e-01, -1.0085e+00, -3.8621e-01,\n",
      "        -2.0892e-02, -3.9909e-01, -1.3909e-01,  4.4303e-01,  1.0679e+00,\n",
      "        -8.4159e-01, -1.0894e-01,  4.3101e-01, -1.6907e-01,  1.1464e-01,\n",
      "        -8.0249e-01, -2.8947e-01,  6.9443e-01,  1.7164e-01,  6.9954e-01,\n",
      "        -5.8951e-01,  5.2774e-01,  3.8727e-01, -1.6565e+00, -1.2911e-01,\n",
      "        -4.6673e-01, -1.4279e+00,  4.9954e-01, -4.8093e-01,  9.0127e-01,\n",
      "         1.0629e+00, -2.5993e-01, -1.0712e+00, -1.6172e-01, -1.6310e+00,\n",
      "        -1.1765e-01,  1.0082e+00,  8.5972e-01,  5.8151e-01,  1.0919e-01,\n",
      "         1.0159e+00,  6.3072e-01,  8.7787e-02,  2.8210e-01, -4.0051e-02,\n",
      "        -1.2063e-01, -1.1931e-01, -1.0029e+00,  3.0189e-01,  3.1638e-02,\n",
      "        -2.6200e-01,  7.8046e-01, -1.0382e+00,  2.7122e-01, -4.3240e-01,\n",
      "         1.0977e+00, -6.3824e-01, -1.2395e+00, -8.0311e-01, -6.8329e-01,\n",
      "         8.4686e-01, -4.4402e-01,  5.1465e-01,  4.0345e-01,  5.2973e-02,\n",
      "        -3.5895e-01,  3.6769e-01, -7.2966e-01,  6.3846e-02, -2.5491e-01,\n",
      "        -1.3762e+00,  9.5290e-01,  1.0384e+00,  6.2277e-01, -5.6496e-01,\n",
      "         1.0850e+00,  1.2306e-01,  1.3831e+00, -4.4560e-02, -7.6367e-01,\n",
      "        -9.4729e-01,  5.0536e-01,  2.4840e-01, -1.7877e-01,  4.2302e-01,\n",
      "        -1.3719e-01,  1.0730e+00,  8.3146e-02, -5.3421e-01,  8.6437e-01,\n",
      "         2.9190e-01,  4.2023e-01,  7.6190e-01,  1.7406e-01, -2.1164e-01,\n",
      "         2.3279e-01,  4.2986e-01, -4.4564e-02,  7.7100e-01,  2.0234e-01,\n",
      "        -3.1252e-01,  3.0249e-01, -1.8905e-01, -3.4307e-01,  1.1294e+00,\n",
      "        -5.4007e-02,  1.0806e+00,  8.5194e-01, -6.3685e-01,  9.0963e-01,\n",
      "         1.5040e+00, -1.6427e-01, -3.1725e-01,  9.1774e-01,  1.0272e-01,\n",
      "        -4.3961e-01, -1.3185e+00,  3.9763e-01,  1.2496e-01,  1.0807e-01,\n",
      "         3.4581e-01, -1.3077e-02,  3.6709e-01,  2.1116e-01, -8.1173e-01,\n",
      "         2.6391e-01, -6.8757e-02, -1.1729e+00,  2.8510e-01, -6.2180e-01,\n",
      "        -1.6956e-01,  6.8330e-02,  7.7788e-02,  4.4313e-01,  1.3177e+00,\n",
      "        -6.5941e-01, -6.9051e-01,  3.1055e-01,  9.3960e-01,  8.7399e-02,\n",
      "         5.2723e-03, -1.7486e+00, -2.7384e-02,  4.4814e-01,  8.8105e-01,\n",
      "         5.0114e-01, -7.3723e-01,  1.2153e+00,  6.1098e-01, -8.1606e-01,\n",
      "        -1.5037e+00,  8.2924e-02,  9.8820e-01,  1.3318e+00, -1.0139e+00,\n",
      "         1.3132e+00,  1.1753e+00,  1.1210e-01,  1.6352e+00,  1.3790e-01,\n",
      "        -3.5416e-01,  3.5884e-01, -4.0405e-01, -7.9753e-01, -1.9505e-01,\n",
      "         7.7268e-01, -2.5183e-01,  8.5426e-01,  3.9385e-01,  1.0010e+00,\n",
      "        -4.4580e-01, -5.3831e-01,  8.3262e-01, -1.2548e-01, -5.6404e-01,\n",
      "        -1.0432e+00, -4.8328e-01, -2.3556e-01,  4.8483e-01,  4.3271e-01,\n",
      "         1.7186e+00, -1.1758e-02, -2.6412e-01, -7.9841e-01,  1.1218e-01,\n",
      "        -3.4835e-02,  4.8928e-01,  2.2770e-01,  2.2314e-01, -9.5641e-01,\n",
      "        -7.9357e-01,  1.4441e+00,  2.5969e-01,  3.4111e-01,  9.2850e-02,\n",
      "        -7.6579e-01, -4.1646e-01,  2.3366e-01, -4.4581e-01, -9.4214e-01,\n",
      "         7.5885e-01,  1.3728e-01, -4.9241e-01,  1.0355e+00,  1.9274e-01,\n",
      "        -4.5514e-01,  2.4399e-01,  4.8954e-01, -6.3986e-01, -7.7669e-01,\n",
      "         1.1575e+00,  3.3725e-01,  1.3426e-01,  5.0415e-01,  7.9946e-01,\n",
      "        -1.6317e+00,  5.1542e-01,  2.0249e-01, -1.7069e-01, -7.8957e-01,\n",
      "        -1.6393e-01, -7.0403e-01,  5.8401e-01, -2.0517e-01, -1.3159e-01,\n",
      "         7.1221e-01, -6.2001e-01,  6.0006e-01, -1.6503e+00,  4.3350e-01,\n",
      "        -7.7825e-01,  7.4609e-01,  1.1532e-01,  4.0220e-01,  8.3949e-01,\n",
      "         1.6690e-02,  1.4298e+00, -4.5481e-01, -2.9181e-01, -1.5218e-01,\n",
      "         4.7675e-01,  7.7750e-01,  1.1742e-01,  6.6526e-01, -2.0501e-01,\n",
      "         7.9335e-01, -1.0084e+00,  1.0591e-01,  4.9006e-01,  3.3001e-01,\n",
      "         8.2972e-01, -1.0422e+00, -2.8397e-01, -3.7651e-01, -5.4660e-02,\n",
      "         1.4490e-01, -7.4660e-01,  2.9423e-01,  2.1471e-01,  5.0403e-01,\n",
      "         4.6970e-01,  3.7982e-02, -6.4309e-01, -5.7236e-01, -3.2481e-02,\n",
      "         1.3093e+00, -9.9847e-01,  7.9114e-01, -3.1830e-01, -1.0495e-02,\n",
      "         4.2059e-03,  2.0767e-01,  1.3358e+00, -1.7283e+00, -9.6841e-02,\n",
      "        -1.1511e+00, -3.5146e-01, -2.4143e-01, -6.9140e-01, -6.6710e-01,\n",
      "         9.6968e-01,  1.3339e-01, -3.8996e-01, -1.2867e+00,  1.7826e-01,\n",
      "        -8.8286e-01, -1.4692e-01, -2.0821e-02,  1.0275e+00,  2.8322e-01,\n",
      "        -9.3601e-01, -1.9263e-01, -4.3183e-02,  5.4718e-01, -6.3018e-01,\n",
      "        -1.9872e-01, -4.5085e-01,  5.9457e-01,  6.2281e-01, -1.3971e+00,\n",
      "         7.0372e-01,  1.5786e-01, -3.3566e-01,  4.8760e-02,  3.3472e-01,\n",
      "         6.6835e-01, -1.3775e-01,  2.2838e-01,  2.1115e-01,  3.4478e-01,\n",
      "         1.6079e+00,  8.7040e-01, -6.7027e-01, -3.1703e-01, -2.2688e-01,\n",
      "         9.4651e-01,  1.0945e-01,  1.0155e+00, -3.4992e-01, -3.5296e-01,\n",
      "        -1.8450e-01, -9.7254e-01,  2.9016e-01,  4.9783e-01, -3.4826e-01,\n",
      "         6.7376e-01, -2.3551e+00, -5.6917e-01,  1.5061e+00, -5.6225e-01,\n",
      "        -4.2600e-03, -9.1920e-01,  9.1233e-02, -2.0019e-01, -7.6966e-01,\n",
      "        -5.1956e-02, -1.2920e+00, -2.1450e-01, -7.4761e-02,  1.8817e-01,\n",
      "        -4.5332e-01, -2.2401e-01, -1.5978e-01,  6.5755e-01, -6.0412e-01,\n",
      "        -1.4221e+00,  4.1427e-01, -1.6817e+00,  4.8986e-01,  6.7391e-01,\n",
      "         2.7911e-01, -6.0357e-02,  9.6674e-01, -5.5019e-01, -8.6601e-01,\n",
      "         3.5937e-01, -7.2869e-01, -8.7914e-01, -9.9928e-02, -1.2486e-01,\n",
      "        -3.3731e-01,  5.5885e-01, -8.1328e-02,  2.6793e+00, -8.8795e-01,\n",
      "         1.1933e-01,  3.1190e-01,  4.0466e-01, -4.7203e-01, -2.3791e-01,\n",
      "        -1.3356e-01,  1.6696e-01,  9.5303e-01,  5.6724e-01, -7.5027e-02,\n",
      "        -7.6032e-01,  8.3296e-01,  1.0127e-01, -9.7102e-01,  1.4432e-01,\n",
      "        -7.5013e-01,  1.4801e+00, -3.5271e-01,  1.0573e+00,  1.8652e+00,\n",
      "        -3.2036e-02,  3.5752e-01, -1.0835e+00,  8.1300e-01,  1.1802e+00,\n",
      "        -2.1574e-01, -6.0806e-02, -4.3312e-01, -4.7603e-02, -1.1378e-01,\n",
      "         9.2245e-01,  7.1005e-01,  3.0241e-01,  5.0147e-01,  1.3727e-01,\n",
      "         6.4281e-01, -2.5700e-01, -1.9329e-01,  2.1251e-01,  8.6884e-01,\n",
      "        -9.2493e-01, -1.9023e+00,  6.0674e-01,  3.2217e-01, -3.0785e-03,\n",
      "        -2.9099e-01,  1.0429e-01,  3.7849e-01,  1.8674e-01,  6.3842e-01,\n",
      "        -2.0777e-01,  1.1186e-01,  6.5097e-02,  1.1862e-01,  4.3234e-01,\n",
      "         1.4494e+00,  1.1168e+00, -1.4307e+00, -8.9427e-01,  4.8667e-01,\n",
      "         9.3186e-02,  1.1336e-01, -4.0116e-01,  3.0253e-01, -7.0833e-01,\n",
      "         6.0083e-01, -2.0232e+00,  5.9397e-01, -4.0550e-01, -1.3934e+00,\n",
      "        -1.2712e+00, -7.6540e-02, -5.3101e-01,  2.5276e-01,  1.1112e+00,\n",
      "         4.7697e-01, -2.8290e-01, -5.1586e-01, -8.5799e-01, -5.3496e-01,\n",
      "         7.9458e-01,  4.1686e-01,  1.5203e-01, -5.4119e-01, -2.2901e+00,\n",
      "         1.0981e+00,  4.4766e-01, -9.7059e-01, -7.4371e-01, -9.4498e-01,\n",
      "         2.2325e-01,  7.4755e-01,  3.5364e-01,  8.9592e-01, -2.0344e+00,\n",
      "        -3.2824e-02,  1.7540e-01, -3.5914e-01, -1.3401e-01, -6.2748e-01,\n",
      "         1.1811e-01,  3.3400e-01,  4.6222e-01, -1.1327e+00,  5.1820e-01,\n",
      "         1.6201e+00, -2.0616e-02, -2.2441e-01, -1.0510e+00,  1.3830e+00,\n",
      "         1.3011e+00,  1.0649e+00,  5.7804e-01,  1.7399e+00, -4.6978e-01,\n",
      "         2.7187e-02, -9.8177e-01, -5.1153e-01, -1.7599e-01,  1.9228e-01,\n",
      "         3.0473e-01, -1.8559e-01,  8.2180e-01, -7.8858e-01, -1.2671e-01,\n",
      "        -8.6285e-01, -2.1705e-01, -4.2760e-02,  1.5126e-01, -2.5508e-01,\n",
      "         1.2159e-01, -3.8058e-01, -7.4547e-01,  2.1520e-01, -2.6971e-01,\n",
      "         4.7432e-01,  1.8943e-01,  1.7864e-01, -3.4366e-02,  1.8447e-01,\n",
      "        -2.7757e-01, -7.1560e-01,  3.4623e-01,  5.0757e-01, -1.2165e-01,\n",
      "        -1.2417e+00, -4.0641e-01,  8.4456e-01,  2.4402e-02, -7.5676e-01,\n",
      "         7.7638e-01, -1.9039e-01, -3.4463e-01,  6.7558e-01, -3.1009e-01,\n",
      "         7.0305e-01,  3.3839e-01, -7.8816e-02, -9.2660e-01, -5.9607e-01,\n",
      "        -1.7617e+00, -9.4937e-01, -4.7651e-01,  5.6551e-01, -5.0515e-01,\n",
      "        -2.5323e-01,  2.3219e-01, -1.8728e-01,  2.8878e-01, -5.3455e-01,\n",
      "         5.1413e-01,  3.0842e-01, -1.1862e-01,  5.9565e-02, -5.5944e-01,\n",
      "         9.9763e-01, -2.2970e-01, -1.3132e+00])\n"
     ]
    }
   ],
   "source": [
    "phrases = [\"it's raining hard\", \"it is wet outside\",\n",
    "           \"cars drive fast\", \"motorcycles are loud\"]\n",
    "embeddings = transformer.encode(phrases, convert_to_tensor=True)\n",
    "print(\"Number of embeddings:\", len(embeddings))\n",
    "print(\"Dimensions per embedding:\", len(embeddings[0]))\n",
    "print(\"The embedding feature values of \\\"it's raining hard\\\":\")\n",
    "print(embeddings[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.7\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The shape of the resulting similarities: torch.Size([4, 4])\n"
     ]
    }
   ],
   "source": [
    "def normalize_embedding(embedding):\n",
    "    normalized = numpy.divide(embedding, numpy.linalg.norm(embedding))\n",
    "    return list(map(float, normalized))\n",
    "\n",
    "normalized_embeddings = list(map(normalize_embedding, embeddings))\n",
    "similarities = sentence_transformers.util.dot_score(normalized_embeddings,\n",
    "                                                    normalized_embeddings)\n",
    "print(\"The shape of the resulting similarities:\", similarities.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 13.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>idx</th>\n",
       "      <th>score</th>\n",
       "      <th>phrase a</th>\n",
       "      <th>phrase b</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.669060</td>\n",
       "      <td>it's raining hard</td>\n",
       "      <td>it is wet outside</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.590783</td>\n",
       "      <td>cars drive fast</td>\n",
       "      <td>motorcycles are loud</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.281166</td>\n",
       "      <td>it's raining hard</td>\n",
       "      <td>cars drive fast</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.280800</td>\n",
       "      <td>it's raining hard</td>\n",
       "      <td>motorcycles are loud</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.204867</td>\n",
       "      <td>it is wet outside</td>\n",
       "      <td>motorcycles are loud</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.138172</td>\n",
       "      <td>it is wet outside</td>\n",
       "      <td>cars drive fast</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def rank_similarities(phrases, similarities):\n",
    "    a_phrases = []\n",
    "    b_phrases = []\n",
    "    scores = []\n",
    "    for a in range(len(similarities) - 1):\n",
    "        for b in range(a + 1, len(similarities)):\n",
    "            a_phrases.append(phrases[a])\n",
    "            b_phrases.append(phrases[b])\n",
    "            scores.append(float(similarities[a][b]))\n",
    "    dataframe = pandas.DataFrame({\"score\": scores,\n",
    "                                  \"phrase a\": a_phrases, \"phrase b\": b_phrases})\n",
    "    dataframe[\"idx\"] = dataframe.index\n",
    "    dataframe = dataframe.reindex(columns=[\"idx\", \"score\", \"phrase a\", \"phrase b\"])\n",
    "    return dataframe.sort_values(by=[\"score\"], ascending=False,\n",
    "                                    ignore_index=True)\n",
    "\n",
    "dataframe = rank_similarities(phrases, similarities)\n",
    "display(HTML(dataframe.to_html(index=False)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up next: [Natural Language Autocomplete](3.natural-language-autocomplete.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
